{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting a Tensorflow Bert model to ONNX\n",
    "\n",
    "This tutorial shows how to convert the original Tensorflow Bert model to ONNX. \n",
    "In this example we fine tune Bert for squad-1.1 on top of [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip).\n",
    "\n",
    "Since this tutorial cares mostly about the conversion process, we reuse tokenizer and utilities defined in the Bert source tree as much as possible.\n",
    "\n",
    "This should work with all versions supported by the [tensorflow-onnx converter](https://github.com/onnx/tensorflow-onnx), we used the following versions while writing the tutorial:\n",
    "```\n",
    "tensorflow-gpu: 1.13.1\n",
    "onnx: 1.5.1\n",
    "tf2onnx: 1.5.1\n",
    "onnxruntime: 0.4\n",
    "```\n",
    "\n",
    "To make the fine tuning work on my Gtx-1080 gpu, we changed the MAX_SEQ_LENGTH to 256 and used a training batch size of 8."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - define some environment variables\n",
    "Before we start, let's set up some variables for where to find things."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "ROOT = os.getcwd()\n",
    "BERT_BASE_DIR = os.path.join(ROOT, 'uncased_L-12_H-768_A-12')\n",
    "SQUAD_DIR = os.path.join(ROOT, 'squad-1.1')\n",
    "OUT = os.path.join(ROOT, 'out')\n",
    "\n",
    "sys.path.append(os.path.join(ROOT, \"bert\"))\n",
    "    \n",
    "os.environ['PYTHONPATH'] = os.path.join(ROOT, \"bert\")\n",
    "os.environ['BERT_BASE_DIR'] = BERT_BASE_DIR\n",
    "os.environ['SQUAD_DIR'] = SQUAD_DIR\n",
    "os.environ['OUT'] = OUT\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - clone the Bert github repository"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'bert'...\n",
      "remote: Enumerating objects: 329, done.\u001b[K\n",
      "remote: Total 329 (delta 0), reused 0 (delta 0), pack-reused 329\u001b[K\n",
      "Receiving objects: 100% (329/329), 234.38 KiB | 0 bytes/s, done.\n",
      "Resolving deltas: 100% (189/189), done.\n",
      "Checking connectivity... done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/google-research/bert bert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - download the pretrained Bert model and squad-1.1 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
    "!unzip uncased_L-12_H-768_A-12.zip\n",
    "\n",
    "!mkdir squad-1.1 out\n",
    "\n",
    "!wget -O squad-1.1/train-v1.1.json  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \n",
    "!wget -O squad-1.1/dev-v1.1.json  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - fine tune the Bert model for squad-1.1\n",
    "This is the same as described in the [Bert repository](https://github.com/google-research/bert). This only needs to be done once.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# finetune bert for squad-1.1\n",
    "# this will take around 3 hours to complete, and even longer if your device does not have a GPU \n",
    "#\n",
    "\n",
    "!cd bert && \\\n",
    "python run_squad.py \\\n",
    "  --vocab_file=$BERT_BASE_DIR/vocab.txt \\\n",
    "  --bert_config_file=$BERT_BASE_DIR/bert_config.json \\\n",
    "  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \\\n",
    "  --do_train=True \\\n",
    "  --train_file=$SQUAD_DIR/train-v1.1.json \\\n",
    "  --do_predict=True \\\n",
    "  --predict_file=$SQUAD_DIR/dev-v1.1.json \\\n",
    "  --train_batch_size=8 \\\n",
    "  --learning_rate=3e-5 \\\n",
    "  --num_train_epochs=2.0 \\\n",
    "  --max_seq_length=256 \\\n",
    "  --doc_stride=128 \\\n",
    "  --output_dir=$OUT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - create the inference graph and save it\n",
    "With a fine-tuned model in hands we want to create the inference graph for it and save it as saved_model format.\n",
    "\n",
    "***We assume that after 2 epochs the checkpoint is model.ckpt-21899 - if the following code does not find it, check the $OUT directory for the higest checkpoint***."
   ]
  },
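  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The checkpoint number depends on the number of training steps, so it may differ on your machine. As a minimal sketch, the next cell derives the highest checkpoint prefix from the `model.ckpt-<step>.index` files that TensorFlow writes into `$OUT`. The helper `highest_checkpoint` is hypothetical and not part of the BERT repository. `tf.train.latest_checkpoint(OUT)` should return the same prefix as long as the `checkpoint` state file is intact."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def highest_checkpoint(filenames):\n",
    "    # return the checkpoint prefix with the largest step number, or None\n",
    "    best_step, best_prefix = -1, None\n",
    "    for name in filenames:\n",
    "        # each checkpoint has one model.ckpt-<step>.index file\n",
    "        if not (name.startswith('model.ckpt-') and name.endswith('.index')):\n",
    "            continue\n",
    "        prefix = name[:-len('.index')]\n",
    "        step = prefix.rsplit('-', 1)[1]\n",
    "        if step.isdigit() and int(step) > best_step:\n",
    "            best_step, best_prefix = int(step), prefix\n",
    "    return best_prefix\n",
    "\n",
    "# assumption: OUT was exported in step 1; fall back to ./out otherwise\n",
    "out_dir = os.environ.get('OUT', os.path.join(os.getcwd(), 'out'))\n",
    "if os.path.isdir(out_dir):\n",
    "    print(highest_checkpoint(os.listdir(out_dir)))"
   ]
  },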
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import modeling\n",
    "import optimization\n",
    "import run_squad\n",
    "import tokenization\n",
    "import modeling\n",
    "import optimization\n",
    "import tokenization\n",
    "import run_squad\n",
    "import six\n",
    "\n",
    "#\n",
    "# define some constants used by the model\n",
    "#\n",
    "MAX_SEQ_LENGTH = 256\n",
    "EVAL_BATCH_SIZE = 8\n",
    "N_BEST_SIZE = 20\n",
    "MAX_ANSWER_LENGTH = 30\n",
    "MAX_QUERY_LENGTH = 64\n",
    "DOC_STRIDE = 128\n",
    "\n",
    "VOCAB_FILE = os.path.join(BERT_BASE_DIR, 'vocab.txt')\n",
    "CONFIG_FILE = os.path.join(BERT_BASE_DIR, 'bert_config.json')\n",
    "CHECKPOINT = os.path.join(OUT, 'model.ckpt-21899')\n",
    "\n",
    "tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, do_lower_case=True)\n",
    "\n",
    "tf.logging.set_verbosity(\"WARN\")\n",
    "\n",
    "# touch flags\n",
    "FLAGS = tf.flags.FLAGS\n",
    "tf.app.flags.DEFINE_string('f', '', 'kernel')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the model, run predictions on all data, and save the results to later compare them to the onnxruntime version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7fb5f1f74f28>) includes params argument, but params are not passed to Estimator.\n",
      "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
     ]
    }
   ],
   "source": [
    "run_config = tf.contrib.tpu.RunConfig(model_dir=OUT, tpu_config=None)\n",
    "\n",
    "model_fn = run_squad.model_fn_builder(\n",
    "    bert_config=modeling.BertConfig.from_json_file(CONFIG_FILE),\n",
    "    init_checkpoint=CHECKPOINT,\n",
    "    learning_rate=0,\n",
    "    num_train_steps=0,\n",
    "    num_warmup_steps=0,\n",
    "    use_tpu=False,\n",
    "    use_one_hot_embeddings=False)\n",
    "\n",
    "estimator = tf.contrib.tpu.TPUEstimator(\n",
    "    use_tpu=False,\n",
    "    model_fn=model_fn,\n",
    "    config=run_config,\n",
    "    predict_batch_size=EVAL_BATCH_SIZE,\n",
    "    export_to_tpu=False)\n",
    "\n",
    "\n",
    "eval_examples = run_squad.read_squad_examples(input_file=os.path.join(SQUAD_DIR, \"dev-v1.1.json\"), is_training=False)\n",
    "eval_writer = run_squad.FeatureWriter(filename=os.path.join(OUT, \"eval.tf_record\"), is_training=False)\n",
    "eval_features = []\n",
    "\n",
    "def append_feature(feature):\n",
    "    eval_features.append(feature)\n",
    "    eval_writer.process_feature(feature)\n",
    "\n",
    "run_squad.convert_examples_to_features(\n",
    "    examples=eval_examples,\n",
    "    tokenizer=tokenizer,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    doc_stride=DOC_STRIDE,\n",
    "    max_query_length=MAX_QUERY_LENGTH,\n",
    "    is_training=False,\n",
    "    output_fn=append_feature)\n",
    "eval_writer.close()\n",
    "\n",
    "predict_input_fn = run_squad.input_fn_builder(\n",
    "    input_file=eval_writer.filename,\n",
    "    seq_length=MAX_SEQ_LENGTH,\n",
    "    is_training=False,\n",
    "    drop_remainder=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing example: 0\n"
     ]
    }
   ],
   "source": [
    "# N is the number of examples we are evaluating. On the CPU this might take a bit.\n",
    "# During development you can set N to some more practical\n",
    "N = len(eval_features)\n",
    "\n",
    "all_results = []\n",
    "for result in estimator.predict(predict_input_fn, yield_single_examples=True):\n",
    "    if len(all_results) % 1000 == 0:\n",
    "        print(\"sample: %d\" % (len(all_results)))\n",
    "    unique_id = int(result[\"unique_ids\"])\n",
    "    start_logits = [float(x) for x in result[\"start_logits\"].flat]\n",
    "    end_logits = [float(x) for x in result[\"end_logits\"].flat]\n",
    "    raw_result = run_squad.RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits)\n",
    "    all_results.append(raw_result)\n",
    "    if len(all_results) >= N:\n",
    "        break\n",
    "    \n",
    "run_squad.write_predictions(eval_examples[:N], eval_features[:N], all_results,\n",
    "                            N_BEST_SIZE, MAX_ANSWER_LENGTH, True, \n",
    "                            os.path.join(OUT, \"predictions.json\"),\n",
    "                            os.path.join(OUT, \"nbest_predictions.json\"), \n",
    "                            os.path.join(OUT, \"null_odds.json\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "Now let's create the inference graph and save it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Export the model\n",
    "def serving_input_fn():\n",
    "    receiver_tensors = {\n",
    "        'unique_ids': tf.placeholder(dtype=tf.int64, shape=[None], name='unique_ids'),\n",
    "        'input_ids': tf.placeholder(dtype=tf.int64, shape=[None, MAX_SEQ_LENGTH], name='input_ids'),\n",
    "        'input_mask': tf.placeholder(dtype=tf.int64, shape=[None, MAX_SEQ_LENGTH], name='input_mask'),\n",
    "        'segment_ids': tf.placeholder(dtype=tf.int64, shape=[None, MAX_SEQ_LENGTH], name='segment_ids')\n",
    "    }\n",
    "    return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors)\n",
    "\n",
    "path = estimator.export_savedmodel(os.path.join(OUT, \"export\"), serving_input_fn)\n",
    "os.environ['LAST_SAVED_MODEL'] = path.decode('utf-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - convert to ONNX\n",
    "\n",
    "Convert the model from Tensorflow to ONNX using https://github.com/onnx/tensorflow-onnx."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install the latest version of tf2onnx if needed\n",
    "!pip install -U tf2onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n",
      "2019-06-10 13:19:47.511598: E tensorflow/stream_executor/cuda/cuda_driver.cc:300] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2019-06-10 13:19:54,043 - INFO - Using tensorflow=1.13.1, onnx=1.5.0, tf2onnx=1.6.0/22481c\n",
      "2019-06-10 13:19:54,043 - INFO - Using opset <onnx, 8>\n",
      "2019-06-10 13:19:57,219 - INFO - \n",
      "2019-06-10 13:19:58,562 - INFO - Optimizing ONNX model\n",
      "2019-06-10 13:19:59,958 - INFO - After optimization: Cast -4 (70->66), Identity -30 (31->1), Transpose -1 (62->61), Unsqueeze -173 (191->18)\n",
      "2019-06-10 13:20:00,031 - INFO - \n",
      "2019-06-10 13:20:00,031 - INFO - Successfully converted TensorFlow model /home/gs/bert/out/export/1560197514 to ONNX\n",
      "2019-06-10 13:20:01,241 - INFO - ONNX model is saved at /home/gs/bert/out/bert.onnx\n"
     ]
    }
   ],
   "source": [
    "# convert model\n",
    "# because we still have a tensorflow session open in this notebook, force the converter to use the CPU.\n",
    "#\n",
    "!CUDA_VISIBLE_DEVICES='' python -m tf2onnx.convert --saved-model $LAST_SAVED_MODEL --output $OUT/bert.onnx --opset 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7 - run the ONNX model under onnxruntime\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the inputs to the ONNX model. The input 'unique_ids' is special and creates an issue in ONNX: the input passed directly to the output and in Tensorflow both have the same name. Because that is not supported in ONNX, the converter creates a new name for the input. We need to use that created name as to remember it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NodeArg(name='unique_ids_raw_output___9:0', type='tensor(int64)', shape=[None])\n",
      "NodeArg(name='segment_ids:0', type='tensor(int64)', shape=[None, 256])\n",
      "NodeArg(name='input_mask:0', type='tensor(int64)', shape=[None, 256])\n",
      "NodeArg(name='input_ids:0', type='tensor(int64)', shape=[None, 256])\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime as ort\n",
    "\n",
    "sess = ort.InferenceSession(os.path.join(OUT, \"bert.onnx\"))\n",
    "for input_meta in sess.get_inputs():\n",
    "    print(input_meta)\n",
    "\n",
    "# remember the name of unique_id\n",
    "unique_id_name = sess.get_inputs()[0].name"
   ]
  },
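  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Taking `sess.get_inputs()[0]` assumes the renamed input is listed first. As a minimal sketch, the next cell instead picks the input whose name starts with `unique_ids`. The helper `find_input_name` is hypothetical, and the prefix match assumes the converter keeps the original name as a prefix of the generated one, as it does in the output above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_input_name(input_names, prefix):\n",
    "    # return the single input name that starts with prefix, fail loudly otherwise\n",
    "    matches = [name for name in input_names if name.startswith(prefix)]\n",
    "    if len(matches) != 1:\n",
    "        raise ValueError('expected one input starting with %r, found %r' % (prefix, matches))\n",
    "    return matches[0]\n",
    "\n",
    "# names copied from the output of the previous cell\n",
    "example_names = ['unique_ids_raw_output___9:0', 'segment_ids:0', 'input_mask:0', 'input_ids:0']\n",
    "print(find_input_name(example_names, 'unique_ids'))\n",
    "\n",
    "# with the live session this would be:\n",
    "# find_input_name([i.name for i in sess.get_inputs()], 'unique_ids')"
   ]
  },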
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "example: 1\n"
     ]
    }
   ],
   "source": [
    "RawResult = collections.namedtuple(\"RawResult\", [\"unique_id\", \"start_logits\", \"end_logits\"])\n",
    "\n",
    "all_results = []\n",
    "for idx in range(0, N):\n",
    "    item = eval_features[idx]\n",
    "    # this is using batch_size=1\n",
    "    # feed the input data as int64\n",
    "    data = {\"unique_ids_raw_output___9:0\": np.array([item.unique_id], dtype=np.int64),\n",
    "            \"input_ids:0\": np.array([item.input_ids], dtype=np.int64),\n",
    "            \"input_mask:0\": np.array([item.input_mask], dtype=np.int64),\n",
    "            \"segment_ids:0\": np.array([item.segment_ids], dtype=np.int64)}\n",
    "    result = sess.run([\"unique_ids:0\", \"unstack:0\", \"unstack:1\"], data)\n",
    "    unique_id = result[0][0]\n",
    "    start_logits = [float(x) for x in result[1][0].flat]\n",
    "    end_logits = [float(x) for x in result[2][0].flat]\n",
    "    all_results.append(RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits))\n",
    "    if unique_id % 1000 == 0:\n",
    "        print(\"sample: %d\" % (len(all_results)))\n",
    "    if len(all_results) >= N:\n",
    "        break\n",
    "\n",
    "run_squad.write_predictions(eval_examples[:N], eval_features[:N], all_results,\n",
    "                            N_BEST_SIZE, MAX_ANSWER_LENGTH, True, \n",
    "                            os.path.join(OUT, \"onnx_predictions.json\"),\n",
    "                            os.path.join(OUT, \"onnx_nbest_predictions.json\"), \n",
    "                            os.path.join(OUT, \"onnx_null_odds.json\"))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare some results between Tensorflow and ONNX:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"56be4db0acb8001400a502ec\": \"Denver Broncos\",\r\n",
      "    \"56be4db0acb8001400a502ed\": \"Carolina Panthers\",\r\n",
      "    \"56be4db0acb8001400a502ee\": \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California\",\r\n",
      "    \"56be4db0acb8001400a502ef\": \"Denver Broncos\",\r\n",
      "    \"56be4db0acb8001400a502f0\": \"gold\",\r\n",
      "    \"56be8e613aeaaa14008c90d1\": \"\\\"golden anniversary\",\r\n",
      "    \"56be8e613aeaaa14008c90d2\": \"February 7, 2016\",\r\n",
      "    \"56be8e613aeaaa14008c90d3\": \"American Football Conference\",\r\n",
      "    \"56bea9923aeaaa14008c91b9\": \"\\\"golden anniversary\",\r\n",
      "    \"56bea9923aeaaa14008c91ba\": \"American Football Conference\",\r\n",
      "    \"56bea9923aeaaa14008c91bb\": \"February 7, 2016\",\r\n",
      "    \"56beace93aeaaa14008c91df\": \"Denver Broncos\",\r\n",
      "    \"56beace93aeaaa14008c91e0\": \"Levi's Stadium\",\r\n",
      "    \"56beace93aeaaa14008c91e1\": \"San Francisco\",\r\n",
      "    \"56beace93aeaaa14008c91e2\": \"Super Bowl L\",\r\n",
      "    \"56beace93aeaaa14008c91e3\": \"2015\",\r\n",
      "    \"56bf10f43aeaaa14008c94fd\": \"2015\",\r\n",
      "    \"56bf10f43aeaaa14008c94fe\": \"San Francisco\",\r\n",
      "    \"56bf10f43aeaaa14008c94ff\": \"Levi's Stadium\",\r\n"
     ]
    }
   ],
   "source": [
    "!head -20 $OUT/predictions.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"56be4db0acb8001400a502ec\": \"Denver Broncos\",\r\n",
      "    \"56be4db0acb8001400a502ed\": \"Carolina Panthers\",\r\n",
      "    \"56be4db0acb8001400a502ee\": \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California\",\r\n",
      "    \"56be4db0acb8001400a502ef\": \"Denver Broncos\",\r\n",
      "    \"56be4db0acb8001400a502f0\": \"gold\",\r\n",
      "    \"56be8e613aeaaa14008c90d1\": \"\\\"golden anniversary\",\r\n",
      "    \"56be8e613aeaaa14008c90d2\": \"February 7, 2016\",\r\n",
      "    \"56be8e613aeaaa14008c90d3\": \"American Football Conference\",\r\n",
      "    \"56bea9923aeaaa14008c91b9\": \"\\\"golden anniversary\",\r\n",
      "    \"56bea9923aeaaa14008c91ba\": \"American Football Conference\",\r\n",
      "    \"56bea9923aeaaa14008c91bb\": \"February 7, 2016\",\r\n",
      "    \"56beace93aeaaa14008c91df\": \"Denver Broncos\",\r\n",
      "    \"56beace93aeaaa14008c91e0\": \"Levi's Stadium\",\r\n",
      "    \"56beace93aeaaa14008c91e1\": \"San Francisco\",\r\n",
      "    \"56beace93aeaaa14008c91e2\": \"Super Bowl L\",\r\n",
      "    \"56beace93aeaaa14008c91e3\": \"2015\",\r\n",
      "    \"56bf10f43aeaaa14008c94fd\": \"2015\",\r\n",
      "    \"56bf10f43aeaaa14008c94fe\": \"San Francisco\",\r\n",
      "    \"56bf10f43aeaaa14008c94ff\": \"Levi's Stadium\",\r\n"
     ]
    }
   ],
   "source": [
    "!head -20 $OUT/onnx_predictions.json"
   ]
  },
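  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking at the first 20 lines only checks a handful of answers. As a minimal sketch, the next cell compares every answer in the two prediction files. `count_mismatches` is a hypothetical helper, and the file paths assume the `OUT` directory from step 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def count_mismatches(expected, actual):\n",
    "    # number of question ids whose answer differs or is missing on one side\n",
    "    ids = set(expected) | set(actual)\n",
    "    return sum(1 for qid in ids if expected.get(qid) != actual.get(qid))\n",
    "\n",
    "# assumption: OUT was exported in step 1; fall back to ./out otherwise\n",
    "out_dir = os.environ.get('OUT', os.path.join(os.getcwd(), 'out'))\n",
    "tf_path = os.path.join(out_dir, 'predictions.json')\n",
    "onnx_path = os.path.join(out_dir, 'onnx_predictions.json')\n",
    "if os.path.exists(tf_path) and os.path.exists(onnx_path):\n",
    "    with open(tf_path) as f_tf, open(onnx_path) as f_onnx:\n",
    "        tf_answers, onnx_answers = json.load(f_tf), json.load(f_onnx)\n",
    "    print('answers: %d, mismatches: %d' % (len(tf_answers), count_mismatches(tf_answers, onnx_answers)))"
   ]
  },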
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "That was all it takes to convert a relatively complex model from Tensorflow to ONNX. \n",
    "\n",
    "You can find more documentation about tensorflow-onnx [here](https://github.com/onnx/tensorflow-onnx)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
